// ///////////////////////////// MIT License //////////////////////////////////// //
//                                                                                //
// Copyright (c) 2010 David Zsolt Manrique                                        //
//                    david.zsolt.manrique@gmail.com                              //
//                                                                                //
// Permission is hereby granted, free of charge, to any person obtaining a copy   //
// of this software and associated documentation files (the "Software"), to deal  //
// in the Software without restriction, including without limitation the rights   //
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell      //
// copies of the Software, and to permit persons to whom the Software is          //
// furnished to do so, subject to the following conditions:                       //
//                                                                                //
// The above copyright notice and this permission notice shall be included in     //
// all copies or substantial portions of the Software.                            //
//                                                                                //
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR     //
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,       //
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE    //
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER         //
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,  //
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN      //
// THE SOFTWARE.                                                                  //
//                                                                                //
// ////////////////////////////////////////////////////////////////////////////// //

#include <iostream>
#include <ostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include <complex>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <fstream>
#include <string>
#include <algorithm>
#include <iomanip>
#include <vector>
#include <list>
#include <set>
#include <cstdarg>
#include <stdexcept>

#include <Eigen/Eigen>
#include <cmdioutil.h>

void print_pattern(std::ostream & os,Eigen::MatrixXcd M,int dn = 1,int dm = 1,double tol = 1.0e-5)
{
    
    os << " --- matrix pattern --- " << std::endl;
    os << " block (rows x cols) : " << dn << "x" << dm << "  pattern : X if block norm > " << std::scientific << tol << std::endl;

    int n = M.rows();
    int m = M.cols();
    
    for(int i = 0; i<n; i+=dn)
    {
        for(int j = 0; j<m; j+=dm)
        {
            if(M.block(i,j,std::min(n-i,dn),std::min(m-j,dm)).norm() > tol)  os << 'X'; else os << '.';
        }

        os << std::endl;
    
    }
    os << " --- matrix pattern end --- " << std::endl;
    
    os << std::endl;
    
}


// Overlap matrix S and Hamiltonian H (complex, same shape) together with a Fermi energy ef.
struct SH
{
    Eigen::MatrixXcd S,H;
    double ef;

    // ef starts at 0 instead of being left uninitialised, e.g. for an SH that is only
    // filled through blockSH::get_part, which copies S and H but not ef.
    SH() : ef(0.0) {}

    // Resize both S and H to r x c.
    void resize(int r,int c) { S.resize(r,c); H.resize(r,c); }

    // Replace S and H by their Hermitian parts (A + A^H)/2. The adjoint is evaluated into
    // a temporary first (.eval()) so the assignment does not read the matrix it overwrites.
    void symetrize() { S = 0.5*(S+S.adjoint().eval()); H = 0.5*(H+H.adjoint().eval()); }

    // Average of the norms of the anti-Hermitian parts S - S^H and H - H^H
    // (0 when both matrices are exactly Hermitian).
    double antisym_norm() const { return 0.5*( (S-S.adjoint().eval()).norm()+(H-H.adjoint().eval()).norm() );  }
};

// True when sh1 and sh2 have both overlap and Hamiltonian within tol of each other,
// measured by the norm of the difference matrix. H is only compared once S matches.
bool approxequal(const SH & sh1,const SH & sh2, double tol = 1.0e-4)
{
    const double overlap_gap = (sh1.S - sh2.S).norm();
    if(!(overlap_gap < tol)) return false;

    const double hamiltonian_gap = (sh1.H - sh2.H).norm();
    return hamiltonian_gap < tol;
}


struct blockSH : SH
{
    std::vector<std::pair<int,int> > orbital_range;
    int n;
    
    void reindex(int & ii,int & jj)
    {
        ii = ii < 0 ? ii + n : ii - 1;
        jj = jj < 0 ? jj + n : jj - 1;

        //if(ii==jj) throw std::runtime_error("Empty index range!");

        if(ii>jj) std::swap(ii,jj);
        
    }    
    
    void get_part(int ii,int jj,SH & sh) const
    {
        int i1 = orbital_range[ii].first;
        int n1 = orbital_range[jj].first+orbital_range[jj].second-i1;
        
        sh.S = S.block(i1,i1,n1,n1);
        sh.H = H.block(i1,i1,n1,n1);
    }

    void set_part(int ii,int jj,const SH & sh) 
    {
        int i1 = orbital_range[ii].first;
        int n1 = orbital_range[jj].first+orbital_range[jj].second-i1;
        
        S.block(i1,i1,n1,n1) = sh.S;
        H.block(i1,i1,n1,n1) = sh.H;
    }

    void get_part_2x2(int ii,int jj,SH & sh11,SH & sh12,SH & sh21, SH & sh22) const
    {
        
        int i1 = orbital_range[ii].first;
        int n1 = orbital_range[jj].first+orbital_range[jj].second-i1;
        
        sh11.S = S.block(i1,i1,n1/2,n1/2);
        sh11.H = H.block(i1,i1,n1/2,n1/2);

        sh12.S = S.block(i1,i1+n1/2,n1/2,n1/2);
        sh12.H = H.block(i1,i1+n1/2,n1/2,n1/2);

        sh21.S = S.block(i1+n1/2,i1,n1/2,n1/2);
        sh21.H = H.block(i1+n1/2,i1,n1/2,n1/2);

        sh22.S = S.block(i1+n1/2,i1+n1/2,n1/2,n1/2);
        sh22.H = H.block(i1+n1/2,i1+n1/2,n1/2,n1/2);
        
    }

    void get_part(int ii1,int jj1,int ii2,int jj2,SH & sh) const
    {
        int i1 = orbital_range[ii1].first;
        int n1 = orbital_range[jj1].first+orbital_range[jj1].second-i1;

        int i2 = orbital_range[ii2].first;
        int n2 = orbital_range[jj2].first+orbital_range[jj2].second-i2;
        
        sh.S = S.block(i1,i2,n1,n2);
        sh.H = H.block(i1,i2,n1,n2);
    }

    void set_part(int ii1,int jj1,int ii2,int jj2,const SH & sh) 
    {
        int i1 = orbital_range[ii1].first;
        int n1 = orbital_range[jj1].first+orbital_range[jj1].second-i1;

        int i2 = orbital_range[ii2].first;
        int n2 = orbital_range[jj2].first+orbital_range[jj2].second-i2;
        
        S.block(i1,i2,n1,n2) = sh.S;
        H.block(i1,i2,n1,n2) = sh.H;
    }

    void get_atom(int I,int J,SH & sh) const
    {
        int i,j,ni,nj;

        i = orbital_range[I].first;
        ni = orbital_range[I].second;
        j = orbital_range[J].first;
        nj = orbital_range[J].second;
        sh.S = S.block(i,j,ni,nj);
        sh.H = H.block(i,j,ni,nj);
        
        sh.ef = ef;
    }

    void print_pattern(std::ostream & os,double tol = 1e-6) const
    {
        os << " --- atom interaction pattern of H and S --- " << std::endl;
        os << " pattern : X if block norm > " << std::scientific << tol << std::endl;
        os << " left H right S " << std::endl;
        int rowcount = 0;
        for(std::vector<std::pair<int,int> >::const_iterator it = orbital_range.begin(); it != orbital_range.end(); it++) 
        {
//            rowcount++; if(rowcount==10) { os << std::endl; rowcount = 0; }
            int colcount = 0;
            for(std::vector<std::pair<int,int> >::const_iterator jt = orbital_range.begin(); jt != orbital_range.end(); jt++)
            {
//                colcount++; if(colcount==10) { os << ' '; colcount = 0; }
                
                int i,j,ni,nj;
                i = it->first;
                ni = it->second;
                j = jt->first;
                nj = jt->second;
                if(H.block(i,j,ni,nj).norm() > tol)  os << 'X'; else os << '.';
            }

            os << "      ";
            for(std::vector<std::pair<int,int> >::const_iterator jt = orbital_range.begin(); jt != orbital_range.end(); jt++)
            {
                int i,j,ni,nj;
                i = it->first;
                ni = it->second;
                j = jt->first;
                nj = jt->second;
                if(S.block(i,j,ni,nj).norm() > tol)  os << 'X'; else os << '.';
            }
            
            
            os << std::endl;
        }
        os << " --- atom interaction pattern of H and S end --- " << std::endl;
        os << std::endl << std::endl;
    }
};

void load(std::string fn,blockSH & sh)
{
    std::complex<double> ii(0,1.0);
    
    std::vector<std::string> lines;
    io::load(fn,"#",lines);

    sh.ef = 0.0;
    std::vector<std::string> efermi_block;
    io::parse_block("<fermi-energy>","</fermi-energy>",lines,efermi_block);
    for(std::vector<std::string>::iterator it = efermi_block.begin(); it != efermi_block.end(); it++)
    {
        std::istringstream iss(*it);
        if (!(iss >> sh.ef)) throw std::runtime_error("Not able to parse fermi energy block in the input file!");
        break;
    }

    std::vector<std::string> orbital_block;
    if(!io::parse_block("<index-orbital>","</index-orbital>",lines,orbital_block)) throw std::runtime_error("No or inproper index-orbital block in the input file!");
    if(!orbital_block.size()) throw std::runtime_error("No or inproper index-orbital block in the input file!");
    
    sh.orbital_range.clear();
    int index,orbital;
    int inc = 0;
    int orbital_index = 0;
    for(std::vector<std::string>::iterator it = orbital_block.begin(); it != orbital_block.end(); it++)
    if(util::match("% %",*it,index,orbital))
    {
            std::pair<int,int> p(orbital_index,orbital);
            orbital_index += orbital;
            sh.orbital_range.push_back(p);
            if(++inc != index) throw std::runtime_error("Index has to be in order from 1 in the index-orbital block!");
    }
    int n = orbital_index;
    sh.H.resize(n,n); sh.H.setZero(); sh.S = sh.H; sh.S.setIdentity();
    sh.n = sh.orbital_range.size();

    double re,im;
    int i,j;
    std::vector<std::string> overlap_block;
    if(!io::parse_block("<overlap>","</overlap>",lines,overlap_block)) throw std::runtime_error("There is no overlap matrix block or it is inproper in the input file!");
    for(std::vector<std::string>::iterator it = overlap_block.begin(); it != overlap_block.end(); it++)
    if(util::match("% % % % ",*it,i,j,re,im)) { if(i<0) i+=n+1; if(j<0) j+=n+1; if((i<n+1)&&(j<n+1)&&(i>0)&&(j>0)) sh.S(i-1,j-1)=re+ii*im; else throw std::runtime_error("Inconsistent overlap matrix index with index orbital block!");  }
    else if(util::match("% % % ",*it,i,j,re)) { if(i<0) i+=n+1; if(j<0) j+=n+1; if((i<n+1)&&(j<n+1)&&(i>0)&&(j>0)) sh.S(i-1,j-1)=re+ii*0.0; else throw std::runtime_error("Inconsistent overlap matrix index with index orbital block!"); }
    
    std::vector<std::string> hamiltonian_block;
    if(!io::parse_block("<hamiltonian>","</hamiltonian>",lines,hamiltonian_block)) throw std::runtime_error("There is no hamiltonian block or it is inproper in the input file!");
    for(std::vector<std::string>::iterator it = hamiltonian_block.begin(); it != hamiltonian_block.end(); it++)
    if(util::match("% % % % ",*it,i,j,re,im)) { if(i<0) i+=n+1; if(j<0) j+=n+1; if((i<n+1)&&(j<n+1)&&(i>0)&&(j>0)) sh.H(i-1,j-1)=re+ii*im; else throw std::runtime_error("Inconsistent hamiltonian index with index orbital block!"); }
    else if(util::match("% % % ",*it,i,j,re)) { if(i<0) i+=n+1; if(j<0) j+=n+1; if((i<n+1)&&(j<n+1)&&(i>0)&&(j>0)) sh.H(i-1,j-1)=re+ii*0.0; else throw std::runtime_error("Inconsistent hamiltonian index with index orbital block!"); }
    if(!hamiltonian_block.size()) throw std::runtime_error("There is no hamiltonian block or it is inproper in the input file!");

}

void save(std::string fn,blockSH & sh)
{
    std::complex<double> ii(0,1.0);
    std::ofstream os(fn.c_str());
    
    os << "<fermi-energy>" << std::endl;
    os << sh.ef << std::endl;
    os << "</fermi-energy>" << std::endl;
    os << std::endl;
    os << "<index-orbital>" << std::endl;
    for(int i=0; i<sh.orbital_range.size(); i++)    
    os << i+1 << " "  << sh.orbital_range[i].second << std::endl;
    os << "</index-orbital>" << std::endl;
    os << std::endl;
    os << "<overlap>" << std::endl;
    for(int i=0;i<sh.S.rows();i++)
    for(int j=0;j<sh.S.cols();j++)
    if(std::abs(sh.S(i,j))> 1e-7 )
    {
        os << "   " << i+1 << "    " << j+1 << "     "<< std::scientific << std::real(sh.S(i,j)) << "    " << std::scientific << std::imag(sh.S(i,j)) << std::endl;
    }
    os << "</overlap>" << std::endl;
    os << std::endl;
    os << "<hamiltonian>" << std::endl;
    for(int i=0;i<sh.H.rows();i++)
    for(int j=0;j<sh.H.cols();j++)
    if(std::abs(sh.H(i,j))> 1e-7 )
    {
        os << "   " << i+1 << "    " << j+1 << "     "<< std::scientific << std::real(sh.H(i,j)) << "    " << std::scientific << std::imag(sh.H(i,j)) << std::endl;
    }
    os << "</hamiltonian>" << std::endl;

    os.close();
    
}


void load_gollum(std::string fn,blockSH & sh)
{
    std::ifstream is(fn.c_str());
    std::complex<double> ii(0,1.0);
    std::string dump;
    if(!(is >> dump >> dump >> dump >> dump >> dump >> dump >> dump)) throw std::runtime_error("Wrong gollum input format: Syntax error near the beginning!");
    if(!(is >> dump >> dump >> dump >> dump >> dump >> dump >> sh.ef )) throw std::runtime_error("Wrong gollum input format: Syntax error near the beginning!");
    if(!(is >> dump >> dump >> dump >> dump >> dump >> dump )) throw std::runtime_error("Wrong gollum input format: Syntax errour near the beginning!");
    int rows = 0;
    int cols = 0; 
    if(!(is >> dump >> dump >> rows )) throw std::runtime_error("Wrong gollum input format: Syntax error near the beginning");
    if(!(is >> dump >> dump >> cols )) throw std::runtime_error("Wrong gollum input format: Syntax error near the beginning");
    int index;
    int pindex = 1;
    int pi = 0;
    int n = rows;    
    sh.orbital_range.clear();
    for(int j = 0; j < n; j++)
    {
        if(!(is >> index >> dump >> dump >> dump >> dump )) throw std::runtime_error("Wrong gollum input format: Syntax error in the orbital block!");
        if(index != pindex)
        {
            std::pair<int,int> p(pi,j-pi);
            sh.orbital_range.push_back(p);
            pindex = index;
            pi = j;
        }
    }
    std::pair<int,int> p(pi,n-pi);
    sh.orbital_range.push_back(p);
    sh.n = sh.orbital_range.size();

    if(!(is >> dump >> dump >> dump >> dump >> dump >> dump )) throw std::runtime_error("Wrong gollum input format: Syntax error near k-block!");
    //int rows,cols;
    if(!(is >> dump >> dump >> rows  )) throw std::runtime_error("Wrong gollum input format: Syntax error near k-block!");
    if(!(is >> dump >> dump >> cols  )) throw std::runtime_error("Wrong gollum input format: Syntax error near k-block!");
    for(int j = 0; j < rows; j++)
        if(!(is >> dump >> dump >> dump  )) throw std::runtime_error("Wrong gollum input format: Syntax error in k-block!");
    if(rows > 1) std::cerr << "Warning: Gollum input delivers multiple k point matrix. This load does not handle that, instead try separate xml inputs for each k-point." << std::endl;
    if(!(is >> dump >> dump >> dump >> dump >> dump >> dump  )) throw std::runtime_error("Wrong gollum input format: Syntax error near matrix block!");
    //int rows,cols;
    if(!(is >> dump >> dump >> rows  )) throw std::runtime_error("Wrong gollum input format: Syntax error near matrix block!");
    if(!(is >> dump >> dump >> cols  )) throw std::runtime_error("Wrong gollum input format: Syntax error near matrix block!");
    sh.H.resize(n,n); sh.H.setZero(); sh.S = sh.H;
    for(int k = 0; k < rows; k++)
    {
        int i,j;
        double res,ims,reh,imh;
        if(!(is >> dump >> i >> j >> res >> ims >> reh >> imh  )) throw std::runtime_error("Wrong gollum input format: Syntax error in matrix block!");
        sh.H(i-1,j-1)=reh+ii*imh;
        sh.S(i-1,j-1)=res+ii*ims;

    }
    is.close();
    
}

// Build a one-orbital-per-site chain from a compact string.
//
// str alternates on-site energies and nearest-neighbour couplings, e.g. "e1,g12,e2,g23,e3"
// (tokens separated by any of ",_=: "). The result has n = (tokens+1)/2 sites, S = identity,
// H real and symmetric with H(i,i) = e_i and H(i,i+1) = H(i+1,i) = g_i, and ef = 0.
// Any previous content of sh is replaced.
// Throws std::runtime_error if the token count is even or a token is not a number.
void interpret_chain(std::string str,blockSH & sh)
{
    std::vector<std::string> items;
    util::tokenize(str,items,",_=: ");

    if(items.size()%2 != 1) throw std::runtime_error("Site-coupling sequence has to be odd: e-g-e-g-e -> 5");

    const int n = static_cast<int>((items.size()+1)/2);
    // Start from an empty site list, as the other loaders do, instead of appending.
    sh.orbital_range.clear();
    for(int i = 0; i < n; i++) sh.orbital_range.push_back(std::pair<int,int>(i,1));

    sh.H.resize(n,n); sh.H.setZero(); sh.S = sh.H; sh.S.setIdentity();
    sh.n = static_cast<int>(sh.orbital_range.size());

    // Even tokens are on-site energies, odd tokens couple site k/2 to site k/2+1.
    for(int k = 0; k < static_cast<int>(items.size()); k++)
    {
        std::istringstream iss(items[k]);
        double value;
        if(!(iss >> value)) throw std::runtime_error("Not able to parse site energy or coupling in the chain: " + items[k]);

        const int site = k/2;
        if(k%2 == 0)
        {
            sh.H(site,site) = value;
        }
        else
        {
            sh.H(site,site+1) = value;
            sh.H(site+1,site) = value;
        }
    }

    sh.ef = 0.0;
}

// Load a dynamical matrix from a tagged text file into sh.H.
//
// Optional blocks <dim> and <atom-number> (first line of each used, both default 1) give
// the number of matrix rows per atom (presumably the spatial dimension) and the number of
// atoms. The matrix is n x n with n = atomnum*dim, grouped into atomnum sites of dim rows.
// The required, non-empty <dynamical-matrix> block holds "i j re im" or "i j re" lines
// (1-based, negative counts from the end; other lines are skipped).
// S is set to the identity (no overlap block is read) and ef to 0.
// Throws std::runtime_error on unparsable or non-positive sizes, a missing/empty matrix
// block, or an out-of-range index.
void load_dynamicalmatrix(std::string fn,blockSH & sh)
{
    std::vector<std::string> lines;
    io::load(fn,"#",lines);

    int dim = 1;
    std::vector<std::string> dim_block;
    io::parse_block("<dim>","</dim>",lines,dim_block);
    if(!dim_block.empty())
    {
        std::istringstream iss(dim_block.front());
        if (!(iss >> dim)) throw std::runtime_error("Not able to parse dim block in the input file!");
    }
    if(dim < 1) throw std::runtime_error("Dim has to be positive in the input file!");

    int atomnum = 1;
    std::vector<std::string> atomnum_block;
    io::parse_block("<atom-number>","</atom-number>",lines,atomnum_block);
    if(!atomnum_block.empty())
    {
        std::istringstream iss(atomnum_block.front());
        if (!(iss >> atomnum)) throw std::runtime_error("Not able to parse atom number block in the input file!");
    }
    if(atomnum < 1) throw std::runtime_error("Atom number has to be positive in the input file!");

    const int n = atomnum*dim;

    sh.ef = 0.0;
    sh.orbital_range.clear();
    for (int i = 0; i < atomnum; i++)
        sh.orbital_range.push_back(std::pair<int,int>(i*dim,dim));

    sh.H.resize(n,n); sh.H.setZero(); sh.S = sh.H; sh.S.setIdentity();
    sh.n = static_cast<int>(sh.orbital_range.size());

    std::vector<std::string> dynamical_block;
    if(!io::parse_block("<dynamical-matrix>","</dynamical-matrix>",lines,dynamical_block)) throw std::runtime_error("There is no dynamical matrix block or it is inproper in the input file!");
    if(dynamical_block.empty()) throw std::runtime_error("There is no dynamical block or it is inproper in the input file!");

    int i,j;
    double re,im;
    for(std::vector<std::string>::iterator it = dynamical_block.begin(); it != dynamical_block.end(); ++it)
    {
        if(!util::match("% % % % ",*it,i,j,re,im))
        {
            if(!util::match("% % % ",*it,i,j,re)) continue;  // not a matrix entry line
            im = 0.0;
        }
        if(i<0) i+=n+1;
        if(j<0) j+=n+1;
        if(i<1 || i>n || j<1 || j>n) throw std::runtime_error("Inconsistent dynamical matrix index with atom-number or dim block!");
        sh.H(i-1,j-1) = std::complex<double>(re,im);
    }
}

// Load a dense real square matrix into sh.H from a plain whitespace-separated text file:
// dim, atomnum, then n*n real entries in row-major order (n = atomnum*dim). The matrix is
// grouped into atomnum sites of dim rows each; S is set to the identity and ef to 0.
// Throws std::runtime_error if the file cannot be opened, the sizes are missing or not
// positive, or fewer than n*n numbers can be read.
void load_sqmat(std::string fn,blockSH & sh)
{
    std::ifstream is(fn.c_str());
    if(!is) throw std::runtime_error("Not able to open square matrix file: " + fn);

    int dim = 1;
    int atomnum = 1;
    // Unchecked reads here used to let an unreadable file turn into an empty matrix.
    if(!(is >> dim >> atomnum)) throw std::runtime_error("Not able to parse dim and atom number in the square matrix file!");
    if(dim < 1 || atomnum < 1) throw std::runtime_error("Dim and atom number have to be positive in the square matrix file!");

    const int n = atomnum*dim;

    sh.ef = 0.0;
    sh.orbital_range.clear();
    for (int i = 0; i < atomnum; i++)
        sh.orbital_range.push_back(std::pair<int,int>(i*dim,dim));

    sh.H.resize(n,n); sh.H.setZero(); sh.S = sh.H; sh.S.setIdentity();
    sh.n = static_cast<int>(sh.orbital_range.size());

    double re;
    for(int i = 0; i < n; i++)
    for(int j = 0; j < n; j++)
    {
        if (!(is >> re)) throw std::runtime_error("Not able to parse square matrix elem!");
        sh.H(i,j) = std::complex<double>(re,0.0);
    }
}
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          